{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train a character-level GPT on some text data\n",
    "\n",
    "The inputs here are simple text files, which we chop up to individual characters and then train GPT on. So you could say this is a char-transformer instead of a char-rnn. Doesn't quite roll off the tongue as well. In this example we will feed it some Shakespeare, which we'll get it to predict character-level."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set up logging\n",
    "import logging\n",
    "logging.basicConfig(\n",
    "        format=\"%(asctime)s - %(levelname)s - %(name)s -   %(message)s\",\n",
    "        datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
    "        level=logging.INFO,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make deterministic\n",
    "from mingpt.utils import set_seed\n",
    "set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing dataset.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile dataset.py\n",
    "\n",
    "import math\n",
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "class CharDataset(Dataset):\n",
    "\n",
    "    def __init__(self, data, block_size):\n",
    "        chars = sorted(list(set(data)))\n",
    "        data_size, vocab_size = len(data), len(chars)\n",
    "        print('data has %d characters, %d unique.' % (data_size, vocab_size))\n",
    "        \n",
    "        self.stoi = { ch:i for i,ch in enumerate(chars) }\n",
    "        self.itos = { i:ch for i,ch in enumerate(chars) }\n",
    "        self.block_size = block_size\n",
    "        self.vocab_size = vocab_size\n",
    "        self.data = data\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.data) - self.block_size\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # grab a chunk of (block_size + 1) characters from the data\n",
    "        chunk = self.data[idx:idx + self.block_size + 1]\n",
    "        # encode every character to an integer\n",
    "        dix = [self.stoi[s] for s in chunk]\n",
    "        \"\"\"\n",
    "        arrange data and targets so that the first i elements of x\n",
    "        will be asked to predict the i-th element of y. Notice that\n",
    "        the eventual language model will actually make block_size\n",
    "        individual predictions at the same time based on this data,\n",
    "        so we are being clever and amortizing the cost of the forward\n",
    "        pass of the network. So for example if block_size is 4, then\n",
    "        we could e.g. sample a chunk of text \"hello\", the integers in\n",
    "        x will correspond to \"hell\" and in y will be \"ello\". This will\n",
    "        then actually \"multitask\" 4 separate examples at the same time\n",
    "        in the language model:\n",
    "        - given just \"h\", please predict \"e\" as next\n",
    "        - given \"he\" please predict \"l\" next\n",
    "        - given \"hel\" predict \"l\" next\n",
    "        - given \"hell\" predict \"o\" next\n",
    "        \n",
    "        In addition, because the DataLoader will create batches of examples,\n",
    "        every forward/backward pass during traning will simultaneously train\n",
    "        a LOT of predictions, amortizing a lot of computation. In particular,\n",
    "        for a batched input of integers X (B, T) where B is batch size and\n",
    "        T is block_size and Y (B, T), the network will during training be\n",
    "        simultaneously training to make B*T predictions, all at once! Of course,\n",
    "        at test time we can paralellize across batch B, but unlike during training\n",
    "        we cannot parallelize across the time dimension T - we have to run\n",
    "        a forward pass of the network to recover the next single character of the \n",
    "        sequence along each batch dimension, and repeatedly always feed in a next\n",
    "        character to get the next one.\n",
    "        \n",
    "        So yes there is a big asymmetry between train/test time of autoregressive\n",
    "        models. During training we can go B*T at a time with every forward pass,\n",
    "        but during test time we can only go B at a time, T times, with T forward \n",
    "        passes.\n",
    "        \"\"\"\n",
    "        x = torch.tensor(dix[:-1], dtype=torch.long)\n",
    "        y = torch.tensor(dix[1:], dtype=torch.long)\n",
    "        return x, y\n"
   ]
  },
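  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The offset between `x` and `y` described in the docstring above can be checked on a toy string. This is a minimal sketch in plain Python rather than the `CharDataset` class itself; every `toy_` name is made up for illustration and is not part of the notebook's dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# toy illustration of the (x, y) offset; all toy_* values are made up for this sketch\n",
    "toy_text = 'hello'\n",
    "toy_block_size = 4\n",
    "\n",
    "# same vocabulary construction as CharDataset: sorted unique characters\n",
    "toy_chars = sorted(set(toy_text))\n",
    "toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}\n",
    "\n",
    "# grab block_size + 1 characters and encode them as integers\n",
    "chunk = toy_text[0:toy_block_size + 1]\n",
    "dix = [toy_stoi[ch] for ch in chunk]\n",
    "\n",
    "# y is x shifted left by one position\n",
    "x_toy, y_toy = dix[:-1], dix[1:]\n",
    "\n",
    "# position t of y is the target for the prefix that ends at position t of x\n",
    "pairs = [(chunk[:t + 1], chunk[t + 1]) for t in range(toy_block_size)]\n",
    "for prefix, target in pairs:\n",
    "    print(f'given {prefix!r}, predict {target!r}')"
   ]
  },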
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "block_size = 128 # spatial extent of the model for its context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "100 1089k  100 1089k    0     0  2915k      0 --:--:-- --:--:-- --:--:-- 2912k\n"
     ]
    }
   ],
   "source": [
    "!curl https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -o input.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data has 1115394 characters, 65 unique.\n"
     ]
    }
   ],
   "source": [
    "# you can download this file at https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt\n",
    "from dataset import CharDataset\n",
    "text = open('input.txt', 'r').read() # don't worry we won't run out of file handles\n",
    "train_dataset = CharDataset(text, block_size) # one line of poem is roughly 50 characters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "01/28/2024 23:02:21 - INFO - mingpt.model -   number of parameters: 2.535219e+07\n"
     ]
    }
   ],
   "source": [
    "from mingpt.model import GPT, GPTConfig\n",
    "mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,\n",
    "                  n_layer=8, n_head=8, n_embd=512)\n",
    "model = GPT(mconf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2179 [00:00<?, ?it/s]/future/u/soumyac/miniconda3/envs/dsp-study/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "epoch 1 iter 2178: train loss 0.34076. lr 3.000169e-04: 100%|██████████| 2179/2179 [13:02<00:00,  2.79it/s]\n",
      "epoch 2 iter 2178: train loss 0.17771. lr 6.000000e-05: 100%|██████████| 2179/2179 [12:56<00:00,  2.81it/s]\n"
     ]
    }
   ],
   "source": [
    "from mingpt.trainer import Trainer, TrainerConfig\n",
    "\n",
    "# initialize a trainer instance and kick off training\n",
    "tconf = TrainerConfig(max_epochs=2, batch_size=512, learning_rate=6e-4,\n",
    "                      lr_decay=True, warmup_tokens=512*20, final_tokens=2*len(train_dataset)*block_size,\n",
    "                      num_workers=4)\n",
    "trainer = Trainer(model, train_dataset, None, tconf)\n",
    "trainer.train()"
   ]
  },
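  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The learning rates in the training log above are consistent with a linear warmup over `warmup_tokens` followed by a cosine decay that bottoms out at 10% of `learning_rate` once `final_tokens` have been processed: the log shows roughly 3e-4 at the end of the first of two epochs and 6e-5 at the end of the second. The trainer's code is not shown in this notebook, so the formula below is an assumption that happens to match those numbers, not a copy of `mingpt.trainer`. The names `lr_multiplier` and `floor` are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# assumed schedule: linear warmup, then cosine decay clipped at `floor`\n",
    "def lr_multiplier(tokens, warmup_tokens, final_tokens, floor=0.1):\n",
    "    if tokens < warmup_tokens:\n",
    "        return tokens / max(1, warmup_tokens)\n",
    "    progress = (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens)\n",
    "    return max(floor, 0.5 * (1.0 + math.cos(math.pi * progress)))\n",
    "\n",
    "# values taken from the TrainerConfig and dataset printout above\n",
    "base_lr = 6e-4\n",
    "warmup = 512 * 20\n",
    "final = 2 * (1115394 - 128) * 128  # 2 * len(train_dataset) * block_size\n",
    "\n",
    "for frac in (0.0, 0.5, 1.0):\n",
    "    tokens = int(frac * final)\n",
    "    lr = base_lr * lr_multiplier(tokens, warmup, final)\n",
    "    print(f'{frac:.1f} of final_tokens: lr = {lr:.6e}')"
   ]
  },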
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "O God, O God! that e'er this tongue of mine,\n",
      "That laid the sentence of dread banishment\n",
      "On yon proud man, should take it off again\n",
      "With words of sooth! O that I were as great\n",
      "As is my grief, or lesser than my name!\n",
      "Or that I could forget what I have been,\n",
      "Or not remember what I must be now!\n",
      "Swell'st thou, proud heart? I'll give thee scope to beat,\n",
      "Since foes have scope to beat both thee and me.\n",
      "\n",
      "DUKE OF AUMERLE:\n",
      "Northumberland comes back from Bolingbroke.\n",
      "\n",
      "KING RICHARD II:\n",
      "What must the king do now? must he submit?\n",
      "The king shall do it: must he be deposed?\n",
      "The king shall be contented: must he lose\n",
      "The name of king? o' God's name, let it go:\n",
      "I'll give my jewels for a set of beads,\n",
      "My gorgeous palace for a hermitage,\n",
      "My gay apparel for an almsman's gown,\n",
      "My figured goblets for a dish of wood,\n",
      "My sceptre for a palmer's walking staff,\n",
      "My subjects for a pair of carved saints\n",
      "And my large kingdom for a little grave,\n",
      "A little little grave, an obscure grave;\n",
      "Or I'll be buried in the king's highway,\n",
      "Some way of common trade, where subjects' feet\n",
      "May hourly trample on their sovereign's head;\n",
      "For on my heart they tread now whilst I live;\n",
      "And buried once, why not upon my head?\n",
      "Aumerle, thou weep'st, my tender-hearted cousin!\n",
      "We'll make foul weather with despised tears;\n",
      "Our sighs and they shall lodge the summer corn,\n",
      "And make a dearth in this revolting land.\n",
      "Or shall we play the wantons with our woes,\n",
      "And make some pretty match with shedding tears?\n",
      "As thus, to drop them still upon one place,\n",
      "Till they have fretted us a pair of graves\n",
      "Within the earth; and, therein laid,--there lies\n",
      "Two kinsmen digg'd their graves with weeping eyes.\n",
      "Would not this ill do well? Well, well, I see\n",
      "I talk but idly, and you laugh at me.\n",
      "Most mighty prince, my Lord Northumberland,\n",
      "What says King Bolingbroke? will his majesty\n",
      "Give Richard leave to live till Richard die?\n",
      "You make a leg, and Bolingbroke says ay.\n",
      "\n",
      "NORTHUMBERLAND:\n",
      "My lord, in the base court he doth attend\n",
      "To speak with you; may it please you to come dow\n"
     ]
    }
   ],
   "source": [
    "# alright, let's sample some character-level Shakespeare\n",
    "from mingpt.utils import sample\n",
    "\n",
    "context = \"O God, O God!\"\n",
    "x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(trainer.device)\n",
    "y = sample(model, x, 2000, temperature=1.0, sample=True, top_k=10)[0]\n",
    "completion = ''.join([train_dataset.itos[int(i)] for i in y])\n",
    "print(completion)"
   ]
  },
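  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `temperature` and `top_k` arguments passed to `sample` above conventionally mean: divide the logits by the temperature, keep only the `top_k` highest-scoring symbols, and draw from the softmax over those. The body of `mingpt.utils.sample` is not shown in this notebook, so the cell below is a standalone sketch of that convention using only the standard library, not the library's implementation. `sample_top_k` and `toy_logits` are hypothetical names with made-up values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import random\n",
    "\n",
    "# sample one index from `logits`, restricted to the k largest after temperature scaling\n",
    "def sample_top_k(logits, k, temperature=1.0, rng=random):\n",
    "    scaled = [v / temperature for v in logits]\n",
    "    # indices of the k largest scaled logits\n",
    "    keep = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:k]\n",
    "    # softmax weights over the kept logits only (max subtracted for numerical stability)\n",
    "    top = max(scaled[i] for i in keep)\n",
    "    weights = [math.exp(scaled[i] - top) for i in keep]\n",
    "    return rng.choices(keep, weights=weights, k=1)[0]\n",
    "\n",
    "toy_logits = [2.0, 1.0, 0.5, -1.0, -3.0]  # made-up scores over a 5-symbol vocabulary\n",
    "toy_rng = random.Random(0)\n",
    "draws = [sample_top_k(toy_logits, k=2, temperature=1.0, rng=toy_rng) for _ in range(1000)]\n",
    "\n",
    "# with k=2 only the two highest-scoring indices can ever be drawn\n",
    "print(sorted(set(draws)))"
   ]
  },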
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# well that was fun"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
